# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# Linear module to use with ZeRO Stage 3 to allow parameter memory to be released
# after the module executes during forward.
# Instead of saving variables using save_for_backward, we save variable ids,
# allowing us to retrieve the variable without creating a pointer to it,
# which allows the underlying tensor to be garbage collected
# when partitioned as needed by the ZeRO Stage 3 optimizer.
# TODO: instead of patching the Linear module, we could patch ctx.save_for_backward and
# ctx.saved_tensors so that this approach works for all nn modules that are built upon
# torch.nn.function. However, the issue is that many modules use C++ implementations
# that do not have a PyTorch implementation, e.g. torch.addmm, which acts as a functional
# when implemented outside of torch.autograd.Function.

import math
import functools

import torch
from torch import Tensor
from torch.nn.parameter import Parameter
from torch.nn import init
from torch.nn.modules.module import Module
from deepspeed.runtime.utils import noop_decorator
from deepspeed import comm as dist
from deepspeed.accelerator import get_accelerator


def print_rank_0(message, debug=False, force=False):
    """Print ``message`` on rank 0 only, and only if ``debug`` or ``force`` is truthy.

    Args:
        message: Object handed to ``print``.
        debug: Enables printing when truthy.
        force: Enables printing when truthy, independently of ``debug``.
    """
    # The rank is queried first, exactly as in a short-circuiting `and`.
    if dist.get_rank() != 0:
        return
    if debug or force:
        print(message)


# Choose the decorators bound to `autocast_custom_fwd` / `autocast_custom_bwd`.
# Both names are always defined after this statement: if any lookup inside the
# `try` raises ImportError or AttributeError, they fall back to `noop_decorator`.
try:
    # Fix `torch.[device].amp.custom_fwd/bwd` FutureWarning in torch 2.4
    if hasattr(torch, 'amp') and hasattr(torch.amp, 'custom_fwd') and hasattr(torch.amp, 'custom_bwd'):
        # torch.amp.custom_fwd/custom_bwd are given the device type as a keyword
        # argument, so pre-bind the current accelerator's device name with partial().
        autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=get_accelerator().device_name())
        autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=get_accelerator().device_name())
    else:
        # original implementation
        # Take the decorators from whatever object the accelerator's amp() returns.
        autocast_custom_fwd = get_accelerator().amp().custom_fwd
        autocast_custom_bwd = get_accelerator().amp().custom_bwd
except (ImportError, AttributeError) as exp:
    # NOTE(review): `exp` is bound but never used.
    # Presumably `noop_decorator` leaves the decorated function unchanged -- confirm
    # against deepspeed.runtime.utils.
    autocast_custom_fwd = noop_decorator
    autocast_custom_bwd = noop_decorator


class LinearFunctionForZeroStage3(torch.autograd.Function):
    """Autograd function computing ``input @ weight.t()`` plus an optional ``bias``.

    ``forward`` saves ``input``, ``weight`` and ``bias`` with
    ``ctx.save_for_backward`` and ``backward`` returns one gradient per forward
    argument (``None`` for arguments that do not need one).
    """

    # Note that both forward and backward are @staticmethods
    @staticmethod
    @autocast_custom_fwd
    # bias is an optional argument
    def forward(ctx, input, weight, bias=None):
        """Compute the linear transformation.

        Args:
            ctx: Autograd context used to stash tensors for ``backward``.
            input: Tensor whose last dimension matches ``weight``'s last dimension.
            weight: 2-D tensor; it is transposed with ``.t()`` before the product.
            bias: Optional tensor added to the product, or ``None``.

        Returns:
            ``input @ weight.t()``, with ``bias`` added when it is not ``None``.
        """
        ctx.save_for_backward(input, weight, bias)

        if input.dim() == 2 and bias is not None:
            # fused op is marginally faster
            return torch.addmm(bias, input, weight.t())

        output = input.matmul(weight.t())
        if bias is not None:
            # In-place add is safe: `output` was freshly allocated just above.
            output += bias
        return output

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    @autocast_custom_bwd
    def backward(ctx, grad_output):
        """Compute gradients with respect to ``input``, ``weight`` and ``bias``.

        Args:
            ctx: Autograd context populated by ``forward``.
            grad_output: Gradient of the loss with respect to the forward output.

        Returns:
            Tuple ``(grad_input, grad_weight, grad_bias)``; an entry is ``None``
            when the corresponding forward argument does not require a gradient
            (or, for the bias, when no bias was given).
        """
        # Unpack saved tensors and initialize all gradients to None. Trailing
        # Nones are ignored by autograd, which keeps the return statement simple
        # even with optional inputs.
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        # Computed up front because both the weight and the bias branch read it.
        # It used to be assigned only inside the weight branch, so a frozen weight
        # combined with a trainable bias raised UnboundLocalError.
        dim = grad_output.dim()

        # The needs_input_grad checks only avoid unnecessary work; returning a
        # gradient for an input that does not require one is not an error.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.matmul(weight)

        if ctx.needs_input_grad[1]:
            if dim == 2:
                grad_weight = grad_output.t().matmul(input)
            else:
                # Collapse all leading dimensions into one so a single 2-D matmul
                # accumulates over them. A 1-D grad_output/input pair becomes a
                # single row, which yields the required outer product.
                flat_grad = grad_output.reshape(-1, grad_output.shape[-1])
                flat_input = input.reshape(-1, input.shape[-1])
                grad_weight = flat_grad.t().matmul(flat_input)

        if bias is not None and ctx.needs_input_grad[2]:
            if dim > 2:
                # Reduce over every dimension except the last (feature) one.
                grad_bias = grad_output.sum(list(range(dim - 1)))
            elif dim == 2:
                grad_bias = grad_output.sum(0)
            else:
                # 1-D output: there is no leading dimension to reduce over.
                grad_bias = grad_output

        return grad_input, grad_weight, grad_bias


def zero3_linear_wrap(input, weight, bias=None):
    """Apply ``LinearFunctionForZeroStage3`` to ``input`` and ``weight``.

    ``bias`` is forwarded as a third positional argument only when it is not
    ``None``; otherwise the function is applied to ``(input, weight)`` alone.
    """
    call_args = (input, weight) if bias is None else (input, weight, bias)
    return LinearFunctionForZeroStage3.apply(*call_args)


class LinearModuleForZeroStage3(Module):
    r"""Linear layer :math:`y = xA^T + b` evaluated via ``LinearFunctionForZeroStage3``.

    The weight is stored with shape ``(out_features, in_features)`` and the
    forward pass delegates to ``LinearFunctionForZeroStage3.apply``.

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``

    Shape:
        - Input: :math:`(N, *, H_{in})` where :math:`*` means any number of
          additional dimensions and :math:`H_{in} = \text{in\_features}`
        - Output: :math:`(N, *, H_{out})` where all but the last dimension
          are the same shape as the input and :math:`H_{out} = \text{out\_features}`.

    Attributes:
        weight: the learnable weights of the module of shape
            :math:`(\text{out\_features}, \text{in\_features})`, initialized by
            ``init.kaiming_uniform_`` with ``a=math.sqrt(5)``.
        bias:   the learnable bias of the module of shape :math:`(\text{out\_features})`,
                or ``None`` when constructed with ``bias=False``. When present, the
                values are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
                where :math:`k = \frac{1}{\text{fan\_in}}` of the weight.

    Examples::

        >>> m = LinearModuleForZeroStage3(20, 30)
        Building ZeRO module
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: Tensor

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        """Allocate the weight (and optional bias) and initialize them."""
        super().__init__()
        print("Building ZeRO module")
        self.in_features = in_features
        self.out_features = out_features
        # Weight is registered before bias so parameter ordering is unchanged.
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if not bias:
            # Register an explicit None so `self.bias` exists and reads as None.
            self.register_parameter('bias', None)
        else:
            self.bias = Parameter(torch.Tensor(out_features))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialize the weight and, when present, the bias in place."""
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is None:
            return
        # The bias range is derived from the weight's fan-in.
        weight_fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
        limit = 1 / math.sqrt(weight_fan_in)
        init.uniform_(self.bias, -limit, limit)

    def forward(self, input: Tensor) -> Tensor:
        """Return the result of applying ``LinearFunctionForZeroStage3`` to ``input``."""
        return LinearFunctionForZeroStage3.apply(input, self.weight, self.bias)

    def extra_repr(self) -> str:
        """Return the feature sizes and whether a bias is present, for ``repr``."""
        has_bias = self.bias is not None
        return 'in_features={}, out_features={}, bias={}'.format(self.in_features, self.out_features, has_bias)
